package com.twelvet.gateway.filter;

import java.nio.charset.StandardCharsets;

import com.twelvet.gateway.config.properties.XssProperties;
import com.twelvet.framework.utils.StringUtils;
import com.twelvet.framework.utils.html.EscapeUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import io.netty.buffer.ByteBufAllocator;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * @author twelvet
 * @WebSite www.twelvet.cn
 * @Description: Cross-site scripting (XSS) filter.
 * <p>
 * Rewrites the body of JSON requests through {@link EscapeUtil#clean(String)} before
 * they are routed downstream. GET and DELETE requests, non-JSON requests and request
 * paths matched by {@link XssProperties#getExcludeUrls()} pass through untouched. The
 * filter is only registered when {@code security.xss.enabled} is {@code true}.
 */
@Component
@ConditionalOnProperty(value = "security.xss.enabled", havingValue = "true")
public class XssFilter implements GlobalFilter, Ordered {

	/**
	 * Factory used to allocate the rewritten request body; created once instead of once
	 * per body buffer.
	 */
	private static final NettyDataBufferFactory BUFFER_FACTORY = new NettyDataBufferFactory(
			ByteBufAllocator.DEFAULT);

	/**
	 * XSS configuration (including the exclude URLs); must be added to the Nacos
	 * configuration by the deployer.
	 */
	@Autowired
	private XssProperties xss;

	/**
	 * Wraps eligible requests so that their body is XSS-cleaned when read downstream.
	 * @param exchange the current server exchange
	 * @param chain the remaining gateway filter chain
	 * @return completion signal of the filter chain
	 */
	@Override
	public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
		ServerHttpRequest request = exchange.getRequest();
		// GET and DELETE requests are not filtered
		HttpMethod method = request.getMethod();
		if (method == null || method.matches(HttpMethod.GET.name()) || method.matches(HttpMethod.DELETE.name())) {
			return chain.filter(exchange);
		}
		// Non-JSON requests are not filtered
		if (!isJsonRequest(exchange)) {
			return chain.filter(exchange);
		}
		// Paths matching the configured exclude URLs are not filtered
		String url = request.getURI().getPath();
		if (StringUtils.matches(url, xss.getExcludeUrls())) {
			return chain.filter(exchange);
		}
		ServerHttpRequestDecorator httpRequestDecorator = requestDecorator(exchange);
		return chain.filter(exchange.mutate().request(httpRequestDecorator).build());
	}

	/**
	 * Builds a request decorator whose body is XSS-cleaned and whose headers are adjusted
	 * for the rewritten (length-changing) body.
	 * @param exchange the current server exchange
	 * @return decorator wrapping the original request
	 */
	private ServerHttpRequestDecorator requestDecorator(ServerWebExchange exchange) {
		return new ServerHttpRequestDecorator(exchange.getRequest()) {
			@Override
			public Flux<DataBuffer> getBody() {
				// Aggregate the whole body before cleaning: processing each chunk on its
				// own corrupts multi-byte UTF-8 characters split across buffers and lets
				// payloads split across buffer boundaries slip past the XSS cleaner.
				// An empty body yields an empty Mono, hence an empty Flux.
				return DataBufferUtils.join(super.getBody()).map(XssFilter::cleanBody).flux();
			}

			@Override
			public HttpHeaders getHeaders() {
				HttpHeaders httpHeaders = new HttpHeaders();
				httpHeaders.putAll(super.getHeaders());
				// The body is rewritten, so its length is no longer known up front: drop
				// the original Content-Length and send the body chunked instead.
				httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
				httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
				return httpHeaders;
			}

		};
	}

	/**
	 * Decodes the complete request body as UTF-8, runs it through
	 * {@link EscapeUtil#clean(String)} and re-encodes it as UTF-8 into a new buffer.
	 * The input buffer is released after it has been read.
	 * @param dataBuffer the aggregated request body
	 * @return a new buffer holding the cleaned body
	 */
	private static DataBuffer cleanBody(DataBuffer dataBuffer) {
		byte[] content = new byte[dataBuffer.readableByteCount()];
		dataBuffer.read(content);
		DataBufferUtils.release(dataBuffer);
		String bodyStr = new String(content, StandardCharsets.UTF_8);
		// XSS cleaning
		bodyStr = EscapeUtil.clean(bodyStr);
		// Re-encode with the same charset used for decoding, not the platform default
		byte[] bytes = bodyStr.getBytes(StandardCharsets.UTF_8);
		DataBuffer buffer = BUFFER_FACTORY.allocateBuffer(bytes.length);
		buffer.write(bytes);
		return buffer;
	}

	/**
	 * Tells whether the request declares a JSON body.
	 * @param exchange the current server exchange
	 * @return {@code true} if the Content-Type header starts (case-insensitively) with
	 * {@code application/json}
	 */
	public boolean isJsonRequest(ServerWebExchange exchange) {
		String header = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
		return StringUtils.startsWithIgnoreCase(header, MediaType.APPLICATION_JSON_VALUE);
	}

	/**
	 * @return the filter order; lower values run earlier in the gateway filter chain
	 */
	@Override
	public int getOrder() {
		return -100;
	}

}
